      
import numpy as np
import torch as th
import random
import os
import yaml
import omnigibson as og
import omnigibson.lazy as lazy
from omnigibson.macros import gm
from omnigibson.objects import REGISTERED_OBJECTS
from omnigibson.utils.python_utils import create_class_from_registry_and_config
from omnigibson.utils.constants import PrimType

class EnvAugmentor:
    """Environment augmentor: perturbs the simulated scene state.

    Supported perturbations are object replacement, object position offsets,
    light intensity / colour, the table texture, and the mass / density / size /
    texture of specific named objects. Each applied perturbation records its
    parameter in ``self.log`` under a short key (usually the method name).
    ``restore2defaultstate`` rebuilds the default scene and resets the log.

    Args:
        default_obj_config_path: Path to a YAML file with an ``objects`` list
            (and optionally ``cup_texture_path``) describing the default scene.
        default_texture_path: Table texture restored by ``restore2defaultstate``.
        domain: Stored on the instance; not read by this class.
        light_colors: Optional candidate colours (RGB triples) sampled by
            ``adjust_light_color_random``.
    """

    # Config keys consumed by the random-augmentation helpers, each paired with
    # the method that applies it. Tuple order is the order in which sampled
    # augmentations run, so objects are replaced before they are moved.
    _RANDOM_AUG_METHODS = (
        ("obj_config", "replace_object"),
        ("position_offset", "change_obj_pose"),
        ("light_intensity", "adjust_light_intensity"),
        ("light_color", "adjust_light_color"),
        ("texture_path", "change_texture"),
    )

    # Child-name paths, relative to the table prim, where the table-top diffuse
    # texture sampler may live. They are tried in order. The second layout was
    # previously kept as a commented-out alternative implementation.
    _TABLE_TEXTURE_PATHS = (
        ("geometry", "collision", "boardMat", "diffuseTexture"),
        ("Looks", "boardMat", "diffuseTexture"),
    )

    def __init__(self, default_obj_config_path, default_texture_path, domain=None, light_colors=None):
        self.log = dict()
        self.default_obj_config_path = default_obj_config_path
        self.default_texture_path = default_texture_path
        self.domain = domain
        # Objects moved by change_obj_pose and re-created by restore2defaultstate.
        self.registered_objects = ["ball0"]
        # Candidate colours for adjust_light_color_random (None: pass them per call).
        self.light_color = light_colors

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _describe_config(obj_config):
        """Return a short identifier for an object config.

        The identifier is ``"<category>_<model>"`` when a category is given,
        otherwise the ``usd_path``. Returns None if neither key is present.
        """
        if "category" in obj_config:
            return f"{obj_config['category']}_{obj_config['model']}"
        if "usd_path" in obj_config:
            return str(obj_config["usd_path"])
        return None

    @staticmethod
    def _spawn_object(env, obj_config, reset_dynamics=False):
        """Instantiate ``obj_config``, add it to ``env.scene`` and place it.

        ``position`` / ``orientation`` are read with ``dict.get`` so the
        caller's config is left intact and can be reused on later calls.
        When ``reset_dynamics`` is true the object is reset and its linear and
        angular velocities are zeroed before it is placed.
        """
        obj = create_class_from_registry_and_config(
            cls_name=obj_config["type"],
            cls_registry=REGISTERED_OBJECTS,
            cfg=obj_config,
            cls_type_descriptor="object",
        )
        env.scene.add_object(obj)
        if reset_dynamics:
            obj.reset()
            obj.set_linear_velocity(velocity=th.zeros(3))
            obj.set_angular_velocity(velocity=th.zeros(3))
        obj.set_position_orientation(
            position=obj_config.get("position"),
            orientation=obj_config.get("orientation"),
            frame="scene",
        )
        return obj

    def _load_default_cfg(self):
        """Load and return the default scene YAML config.

        The file is closed before any scene editing happens.
        """
        with open(self.default_obj_config_path, "r") as f:
            # NOTE(review): FullLoader is kept for compatibility with existing
            # configs. The file is assumed to be trusted, local configuration.
            return yaml.load(f, Loader=yaml.FullLoader)

    def _update_scene_lights(self, env, attr_type, value):
        """Apply ``_recursive_light_update`` to every initial scene object."""
        for obj in env.scene._init_objs.values():
            self._recursive_light_update(obj._prim, attr_type, value)

    @staticmethod
    def _find_material_prims(prim):
        """Return the descendants of ``prim`` whose name contains "material".

        The match is case-insensitive and ``prim`` itself is excluded. Results
        are in depth-first pre-order, and matching prims are also searched.
        """
        found = []
        for child in prim.GetChildren():
            if "material" in child.GetName().lower():
                found.append(child)
            found.extend(EnvAugmentor._find_material_prims(child))
        return found

    # ------------------------------------------------------- object properties

    def change_density(self, env, density):
        """Set the density of the ``cup0`` object.

        For rigid objects, every link with collision meshes gets the new
        density. For cloth objects, the root-link mass becomes
        ``density * volume``. Other prim types are left unchanged, but the
        value is still logged.
        """
        obj = env.scene.object_registry("name", "cup0")
        if obj._prim_type == PrimType.RIGID:
            for link in obj._links.values():
                if link.has_collision_meshes:
                    # Clear the explicit mass first, presumably so that the new
                    # density takes effect (UsdPhysics MassAPI lets a non-zero
                    # mass take precedence over density). TODO confirm.
                    link.mass = 0.0
                    link.density = density
        elif obj._prim_type == PrimType.CLOTH:
            obj.root_link.mass = density * obj.root_link.volume

        self.log['change_density'] = density

    def change_cup_texture(self, env, texture_path):
        """Set ``inputs:diffuse_texture`` on the material prims of ``ball0``.

        The search starts at the object's ``geometry`` child when that child
        exists, otherwise at the object prim itself.
        NOTE(review): despite the method name, the target object is "ball0".
        """
        obj = env.scene.object_registry("name", "ball0")
        geometry_prim = obj._prim.GetChild("geometry")
        # An invalid (missing) USD prim is falsy, so fall back to the root prim.
        search_root = geometry_prim if geometry_prim else obj._prim
        for material_prim in self._find_material_prims(search_root):
            texture_attr = material_prim.GetAttribute("inputs:diffuse_texture")
            if texture_attr:
                texture_attr.Set(texture_path)
        self.log['change_cup_texture'] = texture_path

    def change_mass(self, env, mass):
        """Set ``mass`` on every collision-bearing link of ``cup0``.

        Only rigid objects are modified.
        """
        obj = env.scene.object_registry("name", "cup0")
        if obj._prim_type == PrimType.RIGID:
            for link in obj._links.values():
                if link.has_collision_meshes:
                    link.mass = mass

        self.log['change_mass'] = mass

    def change_size(self, env, scale):
        """Rescale ``cup0``. The simulator is stopped during the change.

        Args:
            scale: ``(factor, mode)``. When ``mode == 1`` the whole scale vector
                is multiplied by ``factor``. Otherwise only x and y are
                multiplied and z is kept.
        """
        obj = env.scene.object_registry("name", "cup0")
        og.sim.stop()
        factor = scale[0]
        if scale[1] == 1:
            obj.scale = factor * obj.scale
        else:
            current = obj.scale
            obj.scale = th.tensor([factor * current[0], factor * current[1], current[2]])
        og.sim.play()
        self.log['change_scale'] = scale

    def change_obj_pose(self, env, position_offset):
        """Shift each registered object in x/y by ``position_offset``.

        ``position_offset`` is ``(dx, dy)`` or ``(dx, dy, dx2, dy2)``. The extra
        offset is added only for an object named "plate0". Each object is reset
        and its velocities are zeroed before being moved.

        Raises:
            ValueError: if a registered object is not in the scene.
        """
        for obj_name in self.registered_objects:
            obj = next((o for o in env.scene.objects if o.name == obj_name), None)
            if obj is None:
                raise ValueError(f"Could not find {obj_name} in environment")

            ori_pos = obj.get_position()
            new_pos = [position_offset[0] + ori_pos[0], position_offset[1] + ori_pos[1], ori_pos[2]]
            if len(position_offset) == 4 and obj_name == "plate0":
                new_pos = [new_pos[0] + position_offset[2], new_pos[1] + position_offset[3], new_pos[2]]
            obj.reset()
            obj.set_linear_velocity(velocity=th.zeros(3))
            obj.set_angular_velocity(velocity=th.zeros(3))
            obj.set_position(new_pos)
            self.log['change_obj_pose'] = position_offset

    # ----------------------------------------------------------------- lighting

    def adjust_light_intensity(self, env, light_intensity):
        """Scale every scene light's intensity to ``gm.FORCE_LIGHT_INTENSITY * light_intensity``."""
        self._update_scene_lights(env, 'intensity', light_intensity)
        self.log['adjust_light_intensity'] = light_intensity

    def adjust_light_color(self, env, light_color):
        """Set every scene light's colour to ``light_color`` (an RGB triple)."""
        self._update_scene_lights(env, 'color', light_color)
        self.log['adjust_light_color'] = light_color

    def adjust_light_color_random(self, env, light_colors=None):
        """Give every scene light a colour drawn independently from ``light_colors``.

        Falls back to ``self.light_color`` when ``light_colors`` is None.

        Raises:
            ValueError: if no candidate colours are available.
        """
        colors = light_colors if light_colors is not None else getattr(self, "light_color", None)
        if colors is None or len(colors) == 0:
            raise ValueError(
                "adjust_light_color_random needs candidate colours: pass light_colors "
                "or set the light_color attribute"
            )

        def visit(prim):
            if "Light" in prim.GetPrimTypeInfo().GetTypeName():
                prim.GetAttribute("inputs:color").Set(lazy.pxr.Gf.Vec3f(random.choice(colors)))
            for child in prim.GetChildren():
                visit(child)

        for obj in env.scene._init_objs.values():
            visit(obj._prim)

    def _recursive_light_update(self, prim, attr_type, value):
        """Set ``attr_type`` on every light prim in the subtree rooted at ``prim``.

        A prim counts as a light when its type name contains "Light".
        ``'intensity'`` writes ``gm.FORCE_LIGHT_INTENSITY * value`` and
        ``'color'`` writes ``Gf.Vec3f(value)``.
        """
        if "Light" in prim.GetPrimTypeInfo().GetTypeName():
            if attr_type == 'intensity':
                prim.GetAttribute("inputs:intensity").Set(gm.FORCE_LIGHT_INTENSITY * value)
            elif attr_type == 'color':
                prim.GetAttribute("inputs:color").Set(lazy.pxr.Gf.Vec3f(value))
        for child_prim in prim.GetChildren():
            self._recursive_light_update(child_prim, attr_type, value)

    # ---------------------------------------------------------------- objects

    def replace_object(self, env, obj_configs):
        """Remove the objects named in ``obj_configs`` and re-create them from the configs.

        The configs are not modified. Logs a tuple of the identifiers produced
        by ``_describe_config`` for the configs that have one.

        Raises:
            ValueError: if an object to replace is not in the scene.
        """
        for obj_config in obj_configs:
            obj = env.scene.object_registry("name", obj_config['name'])
            if obj is None:
                raise ValueError(f"Could not find {obj_config['name']} in environment")
            env.scene.remove_object(obj=obj)

        info = []
        for obj_config in obj_configs:
            self._spawn_object(env, obj_config, reset_dynamics=True)
            description = self._describe_config(obj_config)
            if description is not None:
                info.append(description)

        og.sim.step()
        self.log['replace_object'] = tuple(info)

    def add_plate(self, env):
        """Spawn the second object of the default config (``objects[1]``) into the scene."""
        obj_config = self._load_default_cfg()["objects"][1]
        self._spawn_object(env, obj_config)
        self.log['add_plate'] = self._describe_config(obj_config)

    # ----------------------------------------------------------------- texture

    def change_texture(self, env, texture_path):
        """Set the file of the ``mytable`` diffuse texture sampler to ``texture_path``."""
        table = env.scene.object_registry("name", "mytable")
        self._get_texture_prim(table._prim).GetAttribute("inputs:file").Set(texture_path)
        self.log['change_texture'] = texture_path

    def _get_texture_prim(self, table_prim):
        """Return the table's diffuse-texture sampler prim.

        Each layout in ``_TABLE_TEXTURE_PATHS`` is tried in order.

        Raises:
            ValueError: if no known layout exists under ``table_prim``.
        """
        for path in self._TABLE_TEXTURE_PATHS:
            prim = table_prim
            for child_name in path:
                prim = prim.GetChild(child_name)
                # Stop at the first missing child: calls on an invalid USD
                # prim are not safe.
                if not prim.IsValid():
                    break
            else:
                return prim
        raise ValueError(
            "Could not find the table texture prim under any of: "
            + ", ".join("/".join(p) for p in self._TABLE_TEXTURE_PATHS)
        )

    # --------------------------------------------------- random augmentations

    def _available_augs(self, aug_cfg, keys=None):
        """Return ``(key, bound_method)`` pairs for the augmentations configured in ``aug_cfg``.

        Pairs come back in canonical order. If ``keys`` is given, only those
        keys are considered.

        Raises:
            ValueError: if ``aug_cfg`` is None. The augmentation methods need
                parameters, which come from ``aug_cfg``.
        """
        if aug_cfg is None:
            raise ValueError(
                "random augmentations need an aug_cfg mapping of parameters, e.g. "
                "{'light_intensity': 0.5, 'texture_path': '...'}; supported keys: "
                + ", ".join(k for k, _ in self._RANDOM_AUG_METHODS)
            )
        allowed = None if keys is None else set(keys)
        return [
            (key, getattr(self, method_name))
            for key, method_name in self._RANDOM_AUG_METHODS
            if key in aug_cfg and (allowed is None or key in allowed)
        ]

    def _apply_sampled_augs(self, env, aug_cfg, num_augs, keys=None):
        """Apply ``num_augs`` distinct configured augmentations in canonical order.

        ``num_augs`` is clamped to the number of configured augmentations.
        """
        candidates = self._available_augs(aug_cfg, keys)
        num_augs = min(num_augs, len(candidates))
        # Sort the sampled indices so dependent augmentations keep their order.
        for idx in sorted(random.sample(range(len(candidates)), num_augs)):
            key, func = candidates[idx]
            func(env, aug_cfg[key])

    def apply_random_aug(self, env, aug_cfg=None):
        """Apply one augmentation chosen uniformly at random.

        The candidates are the augmentations configured in ``aug_cfg``, plus
        the combined replace-and-move (``apply_object_pos``) when both
        ``obj_config`` and ``position_offset`` are present. Does nothing if no
        augmentation is configured.
        """
        candidates = [
            (lambda key=key, func=func: func(env, aug_cfg[key]))
            for key, func in self._available_augs(aug_cfg)
        ]
        if "obj_config" in aug_cfg and "position_offset" in aug_cfg:
            candidates.append(lambda: self.apply_object_pos(env, aug_cfg))
        if candidates:
            random.choice(candidates)()

    def apply_object_pos(self, env, aug_cfg=None):
        """Replace objects and then offset their positions (whichever of the two is configured)."""
        self._apply_sampled_augs(env, aug_cfg, 2, keys=("obj_config", "position_offset"))

    def apply_two_aug(self, env, aug_cfg=None):
        """Apply two distinct configured augmentations chosen at random."""
        self._apply_sampled_augs(env, aug_cfg, 2)

    def apply_three_aug(self, env, aug_cfg=None):
        """Apply three distinct configured augmentations chosen at random."""
        self._apply_sampled_augs(env, aug_cfg, 3)

    def apply_random_n_aug(self, env, min_augs=1, max_augs=3, aug_cfg=None):
        """Apply a random number of distinct configured augmentations.

        Args:
            env: Environment to augment.
            min_augs: Minimum number of augmentations to apply (default 1).
            max_augs: Maximum number of augmentations to apply (default 3).
            aug_cfg: Mapping from augmentation keys (see ``_RANDOM_AUG_METHODS``)
                to their parameters.

        The sampled count is clamped to the number of configured augmentations.
        """
        num_augs = random.randint(min_augs, max_augs)
        self._apply_sampled_augs(env, aug_cfg, num_augs)

    # ------------------------------------------------------------------- reset

    def restore2defaultstate(self, env):
        """Restore the default table texture, lights and objects.

        Registered objects still in the scene are removed, and every object in
        the default config is re-created while the simulator is stopped. If the
        config has ``cup_texture_path``, it is applied. The simulator is then
        restarted and stepped once, and the log is reset.

        Returns:
            The identifier of the default ``ball0`` config (see
            ``_describe_config``), or None.
        """
        table = env.scene.object_registry("name", "mytable")
        self._get_texture_prim(table._prim).GetAttribute("inputs:file").Set(self.default_texture_path)
        for obj in env.scene._init_objs.values():
            self._recursive_light_update(obj._prim, 'intensity', 1.0)
            self._recursive_light_update(obj._prim, 'color', [1.0, 1.0, 1.0])

        for obj_name in self.registered_objects:
            obj = env.scene.object_registry("name", obj_name)
            # Tolerate objects that are already gone; the goal is a clean reset.
            if obj is not None:
                env.scene.remove_object(obj=obj)

        cfg = self._load_default_cfg()
        og.sim.stop()
        info = None
        for obj_config in cfg["objects"]:
            self._spawn_object(env, obj_config)
            if obj_config.get('name') == 'ball0':
                info = self._describe_config(obj_config)
        if 'cup_texture_path' in cfg:
            self.change_cup_texture(env, cfg['cup_texture_path'])
        og.sim.play()
        og.sim.step()

        self.log = {'default_obj_config_path': self.default_obj_config_path}
        return info

    def iterate_env_aug(self, env, iter_cfg):
        """Apply every augmentation whose key is present in ``iter_cfg``, in a fixed order."""
        if 'obj_config' in iter_cfg:
            self.replace_object(env, iter_cfg['obj_config'])
        if 'position_offset' in iter_cfg:
            self.change_obj_pose(env, iter_cfg['position_offset'])
        if 'light_color' in iter_cfg:
            self.adjust_light_color(env, iter_cfg['light_color'])
        if 'light_intensity' in iter_cfg:
            self.adjust_light_intensity(env, iter_cfg['light_intensity'])
        if 'texture_path' in iter_cfg:
            self.change_texture(env, iter_cfg['texture_path'])
        if 'mass' in iter_cfg:
            self.change_mass(env, iter_cfg['mass'])
        if 'scale' in iter_cfg:
            self.change_size(env, iter_cfg['scale'])
        if 'cup_texture_path' in iter_cfg:
            self.change_cup_texture(env, iter_cfg['cup_texture_path'])
    